# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
# Modified work Copyright (c) 2025 VAST-AI-Research and contributors.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import torch
from .tables import *

__all__ = [
    'FlexiCubes'
]


class FlexiCubes:
    def __init__(self, device="cuda"):

        self.device = device
        self.dmc_table = torch.tensor(dmc_table, dtype=torch.int32, device=device, requires_grad=False)
        self.num_vd_table = torch.tensor(num_vd_table,
                                         dtype=torch.int32, device=device, requires_grad=False)
        self.check_table = torch.tensor(
            check_table,
            dtype=torch.int32, device=device, requires_grad=False)

        self.tet_table = torch.tensor(tet_table, dtype=torch.int32, device=device, requires_grad=False)
        self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.int32, device=device, requires_grad=False)
        self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.int32, device=device, requires_grad=False)
        self.quad_split_train = torch.tensor(
            [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.int32, device=device, requires_grad=False)

        self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
                                         1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
        self.cube_corners_idx = torch.pow(2, torch.arange(8, dtype=torch.int32, requires_grad=False))
        self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
                                       2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.int32, device=device, requires_grad=False)

        self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
                                           dtype=torch.int32, device=device)
        self.dir_faces_table = torch.tensor([
            [[5, 4], [3, 2], [4, 5], [2, 3]],
            [[5, 4], [1, 0], [4, 5], [0, 1]],
            [[3, 2], [1, 0], [2, 3], [0, 1]]
        ], dtype=torch.int32, device=device)
        self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.int32, device=device)

    def __call__(self, voxelgrid_vertices, scalar_field, cube_idx, resolution, qef_reg_scale=1e-3, 
                 weight_scale=0.99, beta=None, alpha=None, gamma_f=None, voxelgrid_colors=None, training=False, 
                 cube_index_map=None, dtype=None): # dtype so that it can work with half precision.
        """
        Optionally specify 'dtype' to unify all float inputs (e.g. torch.float16).
        If 'dtype' is None, we infer from voxelgrid_vertices or default to float32.
        """
        # unify floating tensors to a chosen dtype, so it can work with half, etc:
        if dtype is None:
            if voxelgrid_vertices.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
                dtype = voxelgrid_vertices.dtype
            else:
                dtype = torch.float32

        # cast all to dtype, so that it works with half-precision:
        voxelgrid_vertices = voxelgrid_vertices.to(dtype)  # Cast the geometry to the chosen dtype
        scalar_field = scalar_field.to(dtype)
        if beta is not None:
            beta = beta.to(dtype)
        if alpha is not None:
            alpha = alpha.to(dtype)
        if gamma_f is not None:
            gamma_f = gamma_f.to(dtype)

        if voxelgrid_colors is not None:
            voxelgrid_colors = voxelgrid_colors.to(dtype)

        surf_cubes, occ_fx8 = self._identify_surf_cubes(scalar_field, cube_idx)
        if surf_cubes.sum() == 0:
            return (
                torch.zeros((0, 3), device=self.device, dtype=dtype),
                torch.zeros((0, 3), dtype=torch.int32, device=self.device),
                torch.zeros((0), device=self.device, dtype=dtype),
                torch.zeros((0, voxelgrid_colors.shape[-1]), device=self.device, dtype=dtype) if voxelgrid_colors is not None else None
            )
        beta, alpha, gamma_f = self._normalize_weights(
            beta, alpha, gamma_f, surf_cubes, weight_scale)
        
        if voxelgrid_colors is not None:
            voxelgrid_colors = torch.sigmoid(voxelgrid_colors)

        case_ids = self._get_case_id_sparse(occ_fx8, surf_cubes, resolution, cube_index_map)
        
        surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(
            scalar_field, cube_idx, surf_cubes
        )

        vd, L_dev, vd_gamma, vd_idx_map, vd_color = self._compute_vd(
            voxelgrid_vertices, cube_idx[surf_cubes], surf_edges, scalar_field,
            case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors)

        vertices, faces, s_edges, edge_indices, vertices_color = self._triangulate(
            scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map,
            vd_idx_map, surf_edges_mask, training, vd_color)

        return vertices, faces, L_dev, vertices_color

    def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
        """
        Regularizer L_dev as in Equation 8
        """
        dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
        mean_l2 = torch.zeros_like(vd[:, 0])
        mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
        mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
        return mad

    def _normalize_weights(self, beta, alpha, gamma_f, surf_cubes, weight_scale):
        """
        Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
        """
        n_cubes = surf_cubes.shape[0]

        if beta is not None:
            beta = (torch.tanh(beta) * weight_scale + 1)
        else:
            beta = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)

        if alpha is not None:
            alpha = (torch.tanh(alpha) * weight_scale + 1)
        else:
            alpha = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)

        if gamma_f is not None:
            gamma_f = torch.sigmoid(gamma_f) * weight_scale + (1 - weight_scale) / 2
        else:
            gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)

        return beta[surf_cubes], alpha[surf_cubes], gamma_f[surf_cubes]

    @torch.no_grad()
    def _get_case_id(self, occ_fx8, surf_cubes, res):
        """
        Obtains the ID of topology cases based on cell corner occupancy. This function resolves the 
        ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the 
        supplementary material. It should be noted that this function assumes a regular grid.
        """
        occ_int = occ_fx8[surf_cubes].to(torch.int32)  # Convert boolean to int32
        corners_idx = self.cube_corners_idx.to(self.device).unsqueeze(0)
        case_ids = (occ_int * corners_idx).sum(-1, dtype=torch.int32) # int32 else sum defaults to int64 by default

        problem_config = self.check_table.to(self.device)[case_ids]
        to_check = problem_config[..., 0] == 1
        problem_config = problem_config[to_check]
        if not isinstance(res, (list, tuple)):
            res = [res, res, res]

        # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array,
        # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
        # This allows efficient checking on adjacent cubes.

        problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.int32)
        vol_idx = torch.nonzero(problem_config_full[..., 0] == 0)  # N, 3
        vol_idx_problem = vol_idx[surf_cubes][to_check]
        problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
        vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]

        within_range = (
            vol_idx_problem_adj[..., 0] >= 0) & (
            vol_idx_problem_adj[..., 0] < res[0]) & (
            vol_idx_problem_adj[..., 1] >= 0) & (
            vol_idx_problem_adj[..., 1] < res[1]) & (
            vol_idx_problem_adj[..., 2] >= 0) & (
            vol_idx_problem_adj[..., 2] < res[2])

        vol_idx_problem = vol_idx_problem[within_range]
        vol_idx_problem_adj = vol_idx_problem_adj[within_range]
        problem_config = problem_config[within_range]
        problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
                                                 vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
        # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
        to_invert = (problem_config_adj[..., 0] == 1)
        idx = torch.arange(case_ids.shape[0], dtype=torch.int32, device=self.device)[to_check][within_range][to_invert]
        case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
        return case_ids

    @torch.no_grad()
    def _get_case_id_sparse(self, occ_fx8, surf_cubes, res, cube_index_map):
        """
        Obtains the ID of topology cases based on cell corner occupancy. This function resolves the 
        ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the 
        supplementary material.
        """
        occ_int = occ_fx8[surf_cubes].to(torch.int32)  # Convert boolean to int32
        corners_idx = self.cube_corners_idx.to(self.device).unsqueeze(0)
        case_ids = (occ_int * corners_idx).sum(-1, dtype=torch.int32) # int32 else sum defaults to int64 by default
        problem_config = self.check_table.to(self.device)[case_ids]
        to_check = problem_config[..., 0] == 1
        
        problem_config = problem_config[to_check]
        problem_config_index = cube_index_map[surf_cubes][to_check]

        # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array,
        # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
        # This allows efficient checking on adjacent cubes.
        vol_idx = cube_index_map  # N, 3
        vol_idx_problem = vol_idx[surf_cubes][to_check]
        
        vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]

        within_range = (
            vol_idx_problem_adj[..., 0] >= 0) & (
            vol_idx_problem_adj[..., 0] < res) & (
            vol_idx_problem_adj[..., 1] >= 0) & (
            vol_idx_problem_adj[..., 1] < res) & (
            vol_idx_problem_adj[..., 2] >= 0) & (
            vol_idx_problem_adj[..., 2] < res)

        vol_idx_problem = vol_idx_problem[within_range]
        vol_idx_problem_adj = vol_idx_problem_adj[within_range]
        problem_config = problem_config[within_range]
        problem_config_index = problem_config_index[within_range]

        inverse_indices = ((vol_idx_problem_adj[..., 0] * res + vol_idx_problem_adj[..., 1]) * res + vol_idx_problem_adj[..., 2]).tolist()
        inverse_cube_index_dict = dict(zip(((problem_config_index[..., 0] * res + problem_config_index[..., 1]) * res + problem_config_index[..., 2]).tolist(), range(len(problem_config))))
        inverse_indices = torch.Tensor([inverse_cube_index_dict[inverse_idx] if inverse_idx in inverse_cube_index_dict else -1 for inverse_idx in inverse_indices]).int().to(self.device)
        problem_config = torch.cat((problem_config, torch.zeros((1, 5), dtype=problem_config.dtype, device=problem_config.device)))
        problem_config_adj = problem_config[inverse_indices]

        # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
        to_invert = (problem_config_adj[..., 0] == 1)
        idx = torch.arange(case_ids.shape[0], dtype=torch.int32, device=self.device)[to_check][within_range][to_invert]
        case_ids.index_put_((idx,), problem_config[:-1][to_invert][..., -1])

        return case_ids
    

    @torch.no_grad()
    def _identify_surf_edges(self, scalar_field, cube_idx, surf_cubes):
        """
        Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge 
        can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge 
        and marks the cube edges with this index.
        """
        occ_n = scalar_field < 0
        all_edges = cube_idx[surf_cubes][:, self.cube_edges].reshape(-1, 2)
        unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)

        unique_edges = unique_edges.int()
        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1

        surf_edges_mask = mask_edges[_idx_map]
        counts = counts[_idx_map]

        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.int32, device=cube_idx.device) * -1
        mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.int32, device=cube_idx.device) # int32 else arrange defaults to int64
        # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
        # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
        idx_map = mapping[_idx_map]
        surf_edges = unique_edges[mask_edges]
        return surf_edges, idx_map, counts, surf_edges_mask

    @torch.no_grad()
    def _identify_surf_cubes(self, scalar_field, cube_idx):
        """
        Identifies grid cubes that intersect with the underlying surface by checking if the signs at 
        all corners are not identical.
        """
        occ_n = scalar_field < 0
        occ_fx8 = occ_n[cube_idx.reshape(-1)].reshape(-1, 8)
        _occ_sum = torch.sum(occ_fx8, -1)
        surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
        return surf_cubes, occ_fx8

    def _linear_interp(self, edges_weight, edges_x):
        """
        Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
        """
        edge_dim = edges_weight.dim() - 2
        assert edges_weight.shape[edge_dim] == 2
        edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), -
                                 torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)]
                                 , edge_dim)
        denominator = edges_weight.sum(edge_dim)
        ue = (edges_x * edges_weight).sum(edge_dim) / denominator
        return ue

    def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3, qef_reg_scale):
        p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
        norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
        c_bx3 = c_bx3.reshape(-1, 3)
        A = norm_bxnx3
        B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)

        A_reg = (torch.eye(3, device=p_bxnx3.device) * qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
        B_reg = (qef_reg_scale * c_bx3).unsqueeze(-1)
        A = torch.cat([A, A_reg], 1)
        B = torch.cat([B, B_reg], 1)
        dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
        return dual_verts

    def _compute_vd(self, voxelgrid_vertices, surf_cubes_fx8, surf_edges, scalar_field,
                    case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors):
        """
        Computes the location of dual vertices as described in Section 4.2
        """
        alpha_nx12x2 = torch.index_select(input=alpha, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
        surf_edges_x = torch.index_select(input=voxelgrid_vertices, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
        surf_edges_s = torch.index_select(input=scalar_field, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
        zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
        
        if voxelgrid_colors is not None:
            C = voxelgrid_colors.shape[-1]
            surf_edges_c = torch.index_select(input=voxelgrid_colors, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, C)
        else:
            C = None
            surf_edges_c = None
        idx_map = idx_map.reshape(-1, 12)
        num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
        edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []
        
        # if color is not None:
        #     vd_color = []

        total_num_vd = 0
        vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.int64, device=self.device, requires_grad=False)

        for num in torch.unique(num_vd):
            cur_cubes = (num_vd == num)  # consider cubes with the same numbers of vd emitted (for batching)
            curr_num_vd = cur_cubes.sum() * num
            curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
            curr_edge_group_to_vd = torch.arange(
                curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
            total_num_vd += curr_num_vd
            curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
                cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)

            curr_mask = (curr_edge_group != -1)
            edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
            edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
            edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
            vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
            vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))
            # if color is not None:
            #     vd_color.append(color[cur_cubes].unsqueeze(1).repeat(1, num, 1).reshape(-1, 3))
            
        edge_group = torch.cat(edge_group)
        edge_group_to_vd = torch.cat(edge_group_to_vd)
        edge_group_to_cube = torch.cat(edge_group_to_cube)
        vd_num_edges = torch.cat(vd_num_edges)
        vd_gamma = torch.cat(vd_gamma)
        # if color is not None:
        #     vd_color = torch.cat(vd_color)
        # else:
        #     vd_color = None

        vd = torch.zeros((total_num_vd, 3), device=self.device, dtype=beta.dtype) # dtype so that it works with half-precision
        beta_sum = torch.zeros((total_num_vd, 1), device=self.device, dtype=beta.dtype) # dtype so that it works with half

        idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)

        x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
        s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
        

        zero_crossing_group = torch.index_select(
            input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)

        alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
                                            index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
        ue_group = self._linear_interp(s_group * alpha_group, x_group)

        beta_group = torch.gather(input=beta.reshape(-1), dim=0,
                                    index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
        beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
        vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
        
        '''
        interpolate colors use the same method as dual vertices
        '''
        if voxelgrid_colors is not None and surf_edges_c is not None:
            vd_color = torch.zeros((total_num_vd, C), device=self.device, dtype=beta.dtype)
            c_group = torch.index_select(input=surf_edges_c, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, C)
            uc_group = self._linear_interp(s_group * alpha_group, c_group)
            vd_color = vd_color.index_add_(0, index=edge_group_to_vd, source=uc_group * beta_group) / beta_sum
        else:
            vd_color = None
        
        L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)

        v_idx = torch.arange(vd.shape[0], dtype=torch.int64, device=self.device)  # + total_num_vd
        index = edge_group_to_cube * 12 + edge_group
        vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=index, src=v_idx[edge_group_to_vd])

        return vd, L_dev, vd_gamma, vd_idx_map, vd_color

    def _triangulate(self, scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, vd_color):
        """
        Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into 
        triangles based on the gamma parameter, as described in Section 4.3.
        """
        with torch.no_grad():
            group_mask = (edge_counts == 4) & surf_edges_mask  # surface edges shared by 4 cubes.
            group = idx_map.reshape(-1)[group_mask]
            vd_idx = vd_idx_map[group_mask]
            edge_indices, indices = torch.sort(group, stable=True)
            indices = indices.to(torch.int32)
            quad_vd_idx = vd_idx[indices].reshape(-1, 4).to(torch.int32)

            # Ensure all face directions point towards the positive SDF to maintain consistent winding.
            s_edges = scalar_field[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
            flip_mask = s_edges[:, 0] > 0
            quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
                                     quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))

        quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
        gamma_02 = quad_gamma[:, 0] * quad_gamma[:, 2]
        gamma_13 = quad_gamma[:, 1] * quad_gamma[:, 3]
        if not training:
            mask = (gamma_02 > gamma_13)
            faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.int32, device=quad_vd_idx.device)
            faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
            faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
            faces = faces.reshape(-1, 3)
        else:
            vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
            vd_02 = (vd_quad[:, 0] + vd_quad[:, 2]) / 2
            vd_13 = (vd_quad[:, 1] + vd_quad[:, 3]) / 2
            weight_sum = (gamma_02 + gamma_13) + 1e-8
            vd_center = (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1)
            
            if vd_color is not None:
                color_quad = torch.index_select(input=vd_color, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, vd_color.shape[-1])
                color_02 = (color_quad[:, 0] + color_quad[:, 2]) / 2
                color_13 = (color_quad[:, 1] + color_quad[:, 3]) / 2
                color_center = (color_02 * gamma_02.unsqueeze(-1) + color_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1)
                vd_color = torch.cat([vd_color, color_center])
            
            
            vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
            vd = torch.cat([vd, vd_center])
            faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
            faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
        return vd, faces, s_edges, edge_indices, vd_color
